
Conversation

@chaserogo

Description

When running with per_device_batch_size < 1, the data loader yields batches larger than the training batch size; instead of splitting those larger batches into the smaller training batch sizes, the current code truncates them, discarding data. This PR splits each larger batch into smaller microbatches and loops over them, so no data is discarded.
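As a sketch of the intended splitting (the helper name explode_to_micro appears in this PR, but this standalone body and the dict-of-NumPy-arrays batch layout are assumptions for illustration, not the actual implementation):

import numpy as np

def explode_to_micro(global_batch, microbatch_size):
  """Split a dict-of-arrays batch along the leading axis into microbatches.

  Illustrative stand-in only: assumes every array shares the same leading
  batch dimension and that it divides evenly by microbatch_size.
  """
  batch_size = next(iter(global_batch.values())).shape[0]
  for start in range(0, batch_size, microbatch_size):
    yield {k: v[start:start + microbatch_size] for k, v in global_batch.items()}

# A batch of 4 examples becomes two microbatches of 2, so nothing is
# truncated away.
batch = {"inputs": np.arange(8).reshape(4, 2)}
micro = list(explode_to_micro(batch, microbatch_size=2))
assert len(micro) == 2 and micro[0]["inputs"].shape == (2, 2)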

Tests

Please describe how you tested this change, and include any instructions and/or
commands to reproduce.

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed.

@google-cla
Copy link

google-cla bot commented Sep 17, 2025

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

@aireenmei (Collaborator)

Thanks for contributing to maxtext! Sorry for the delay in review. Have you completed the CLA? Could you rebase the PR and complete the checklist so we can rerun the tests?

@aireenmei (Collaborator) left a comment


Could you add a unit test in multihost_dataloading_test.py for this feature?
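For what it's worth, a test of the no-truncation property could look roughly like the following; the test class is hypothetical and reuses the illustrative explode_to_micro sketch from the description above, whereas the real test would exercise the actual iterator in multihost_dataloading.py:

import unittest
import numpy as np

class MicrobatchSplitTest(unittest.TestCase):
  # Hypothetical test shape: checks that splitting keeps every example.

  def test_no_examples_discarded(self):
    batch = {"inputs": np.arange(12).reshape(6, 2)}
    micro = list(explode_to_micro(batch, microbatch_size=2))
    self.assertEqual(len(micro), 3)
    # Recombining the microbatches reproduces the full batch exactly.
    recombined = np.concatenate([m["inputs"] for m in micro])
    np.testing.assert_array_equal(recombined, batch["inputs"])

if __name__ == "__main__":
  unittest.main()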


sharded_iter = self._base_iter()
if self.microbatch_size_to_run:
  self.local_iterator = itertools.chain.from_iterable(
@aireenmei (Collaborator)

This can be simplified:

def microbatch_generator():
  for global_batch in sharded_iter:
    yield from self.explode_to_micro(global_batch)

if self.microbatch_size_to_run:
  self.local_iterator = microbatch_generator()
else:
  self.local_iterator = sharded_iter
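(Both forms are equally lazy; the generator version just makes the per-microbatch loop explicit.)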

